// Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef OPEN_SPIEL_EXAMPLES_SBR_BLOTTO_FICTITIOUS_PLAY_H_
#define OPEN_SPIEL_EXAMPLES_SBR_BLOTTO_FICTITIOUS_PLAY_H_

#include <functional>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include "open_spiel/abseil-cpp/absl/random/uniform_int_distribution.h"
#include "open_spiel/abseil-cpp/absl/random/uniform_real_distribution.h"
#include "open_spiel/abseil-cpp/absl/time/time.h"
#include "open_spiel/algorithms/best_response.h"
#include "open_spiel/algorithms/corr_dist.h"
#include "open_spiel/algorithms/tabular_exploitability.h"
#include "open_spiel/policy.h"
#include "open_spiel/spiel.h"
#include "open_spiel/spiel_utils.h"

namespace open_spiel {
namespace algorithms {
namespace blotto_fp {

// How the "base" profile is drawn from the history of past policies
// (see FictitiousPlayProcess::SampleBaseProfile).
enum class BaseSamplerType {
  kBaseUniform,  // Uniform past policy (sampled over past iterations).
  kBaseLatest,   // Use only the most recent policy.
};

// How candidate actions are drawn for a player (see
// FictitiousPlayProcess::SampleCandidate / GetBestCandidate).
enum class CandidatesSamplerType {
  kCandidatesInitial,  // Initial in Blotto is the uniform random policy.
  kCandidatesUniform,  // This is uniform past policy
  kCandidatesLatest,   // Use only the most recent policy.
  // Mixtures of the above. TODO(review): confirm the exact mixing scheme
  // (e.g. 50/50 per sample) against the implementation.
  kCandidatesInitialUniform,  // Mix of initial and uniform-past.
  kCandidatesInitialLatest,   // Mix of initial and latest.
};

// Returns a policy over the action sequence (0, 1, ..., num_actions - 1).
// NOTE(review): per the name, presumably each action receives probability
// 1 / num_actions -- confirm against the implementation.
ActionsAndProbs UniformSequencePolicy(int num_actions);

// Returns a random policy over the action sequence (0, 1, ..., num_actions - 1)
// generated by normalizing random numbers. `rng` supplies the raw draws;
// presumably it returns values in [0, 1) -- TODO(review): confirm the
// expected contract of `rng` against callers.
ActionsAndProbs RandomStatePolicy(int num_actions,
                                  const std::function<double()>& rng);

// Drives several fictitious-play-style learning processes (full FP, stochastic
// FP, sampled best response, iterated best response, BRPI) over a game,
// maintaining cumulative marginal and joint policies, the full history of past
// per-player policies, and the total wall-clock time spent iterating. Joint
// actions are flattened into indices in base NumDistinctActions (see
// cumulative_joint_policy_ below), so this is intended for one-shot
// (Blotto-style) games.
class FictitiousPlayProcess {
 public:
  // `seed` seeds the internal RNG (rng_). If `randomize_initial_policies` is
  // true, initial policies are randomized rather than uniform -- see
  // InitPolicies() / InitPoliciesRandom(). TODO(review): confirm against the
  // .cc implementation.
  FictitiousPlayProcess(std::shared_ptr<const Game> game, int seed,
                        bool randomize_initial_policies);

  // Get the marginalized joint policy from the marginalized counts being
  // maintained in weight tables (defaults to cumulative_policies_ if nullptr,
  // which is marginalized separately each iteration). Weight table is indexed
  // by (player, action).
  void GetMarginalJointPolicy(
      TabularPolicy* policy,
      const std::vector<std::vector<double>>* weight_table = nullptr) const;

  // Get the marginalized joint policy by marginalizing the empirical joint
  // policy.
  void GetMarginalJointPolicyFromEmpirical(TabularPolicy* policy) const;

  // One iteration of each process variant: FullFP = exact fictitious play;
  // SFP = stochastic FP (`lambda` is presumably the Softmax() parameter
  // below); SBR = sampled best response; IBR = iterated best response;
  // MaxEntIBR = maximum-entropy IBR; BRPI takes its base/candidate sampling
  // schemes plus sample counts. TODO(review): confirm `brpi_N` semantics and
  // each method's side effects (past_policies_, iterations_, total_time_)
  // against the implementation.
  void FullFPIteration();
  void SFPIteration(double lambda);
  void SBRIteration(int num_base_samples, int num_candidates);
  void IBRIteration();
  void MaxEntIBRIteration();
  void BRPIIteration(BaseSamplerType base_sampling,
                     CandidatesSamplerType candidates_sampling,
                     int num_base_samples, int num_candidates, int brpi_N);

  // Convenience overload: refreshes joint_policy_ from the cumulative counts
  // and returns a copy. Non-const because it mutates joint_policy_.
  TabularPolicy GetMarginalJointPolicy() {
    GetMarginalJointPolicy(&joint_policy_);
    return joint_policy_;
  }

  // Returns the most recent policy. TODO(review): declaration only --
  // presumably derived from past_policies_.
  TabularPolicy GetLatestPolicy() const;

  // Total wall-clock time accumulated so far (see total_time_).
  absl::Duration TotalTime() const { return total_time_; }

  // Convergence / equilibrium-distance metrics for the current average
  // policies.
  double NashConv() const;
  double CCEDist() const;

 private:
  // Initialize policies (uniform vs. randomized variants; selected by the
  // constructor's randomize_initial_policies flag -- TODO(review): confirm).
  void InitPolicies();
  void InitPoliciesRandom();
  int Iterations() const { return iterations_; }
  // Softmax over `values` with parameter `lambda`. TODO(review): confirm
  // whether lambda scales as temperature (divides) or inverse temperature
  // (multiplies) in the .cc.
  std::vector<double> Softmax(const std::vector<double>& values,
                              double lambda) const;

  // Bijection between joint actions and flat indices (base
  // NumDistinctActions encoding -- see cumulative_joint_policy_).
  int JointActionToIndex(const std::vector<Action>& joint_action) const;
  std::vector<Action> IndexToJointAction(int index) const;

  // Empirical distribution over joint actions, presumably built from
  // cumulative_joint_policy_ -- TODO(review): confirm.
  NormalFormCorrelationDevice GetEmpiricalJointPolicy() const;

  // Add appropriate weights given each players' (potentially mixed) policy
  void UpdateCumulativeJointPolicy(
      const std::vector<std::vector<double>>& policies);
  // Sampled variant of the above: draws num_samples joint actions instead of
  // performing the exact update. TODO(review): confirm.
  void UpdateCumulativeJointPolicySampled(
      const std::vector<std::vector<double>>& policies, int num_samples);

  // Adds `weight` to `action`'s entry in `policy`.
  void AddWeight(ActionsAndProbs* policy, Action action, double weight) const;
  // Best response for `player` against the empirical joint (resp.
  // marginalized) policy; if `values` is non-null it is presumably filled
  // with per-action expected values -- TODO(review): confirm.
  Action BestResponseAgainstEmpiricalJointPolicy(
      Player player, std::vector<double>* values = nullptr);
  Action BestResponseAgainstEmpiricalMarginalizedPolicies(
      Player player, std::vector<double>* values = nullptr);

  // Ensures cached_joint_utilities_ is populated (lazily built --
  // TODO(review): confirm; declaration only).
  void CheckJointUtilitiesCache();

  // Sampling helpers used by the SBR/BRPI iterations.
  std::vector<Action> SampleBaseProfile(BaseSamplerType sampler_type);
  Action SampleCandidate(Player player, CandidatesSamplerType sampler_type);
  std::vector<std::vector<Action>> SampleBaseProfiles(
      BaseSamplerType sampler_type, int num_base_samples);
  // Returns the best of `num_candidates` sampled candidate actions for
  // `player`, evaluated against `base_samples`.
  Action GetBestCandidate(Player player,
                          const std::vector<std::vector<Action>>& base_samples,
                          int num_candidates,
                          CandidatesSamplerType sampler_type);

  std::mt19937 rng_;  // Seeded from the constructor's `seed`.
  absl::uniform_real_distribution<double> dist_;  // Defaults to [0, 1).

  std::shared_ptr<const Game> game_;
  int num_players_;
  int num_actions_;  // Distinct actions per player.
  std::vector<std::string> infostate_strings_;  // Presumably one per player --
                                                // TODO(review): confirm.

  int iterations_;  // Iterations completed so far (see Iterations()).
  // Scratch/output policy refreshed by GetMarginalJointPolicy().
  TabularPolicy joint_policy_;
  // Cumulative marginal weights, indexed by (player, action) -- the default
  // weight table for GetMarginalJointPolicy().
  std::vector<std::vector<double>> cumulative_policies_;
  // One best-response computer per player -- TODO(review): confirm indexing.
  std::vector<std::unique_ptr<TabularBestResponse>> best_response_computers_;

  // Total number of joint actions (NumDistinctActions^players, per the
  // cumulative_joint_policy_ note below).
  int num_joint_actions_;

  // Histogram of sampled joint actions.
  std::vector<double> current_joint_policy_counts_;
  // Each player's policy: time step by player
  std::vector<std::vector<ActionsAndProbs>> past_policies_;

  // Joint average strategy. Index is an encoding of the joint action in base
  // NumDistinctActions (so this vector has size NumDistinctActions^players).
  std::vector<double> cumulative_joint_policy_;

  // Player by joint index
  std::vector<std::vector<double>> cached_joint_utilities_;

  // Total wall-clock time; presumably accumulated across iteration calls
  // (see TotalTime()) -- TODO(review): confirm what is timed.
  absl::Duration total_time_;
};

}  // namespace blotto_fp
}  // namespace algorithms
}  // namespace open_spiel

#endif  // OPEN_SPIEL_EXAMPLES_SBR_BLOTTO_FICTITIOUS_PLAY_H_
